{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import os\n",
    "\n",
    "## =====================================================================================\n",
    "## This is a temporary fix for the freezing and the CUDA issues. You can add this\n",
    "## utility script instead of kaggle_l5kit until Kaggle resolves these issues.\n",
    "##\n",
    "## You will be able to train and submit your results, but not all the functionality of\n",
    "## l5kit will work properly.\n",
    "\n",
    "## More details here:\n",
    "## https://www.kaggle.com/c/lyft-motion-prediction-autonomous-vehicles/discussion/177125\n",
    "\n",
    "## This cell installs l5kit and its dependencies into /kaggle/working.\n",
    "## l5kit itself is installed with --no-dependencies, so the packages it needs are\n",
    "## listed explicitly before it.\n",
    "pip_commands = [\n",
    "    'pip install --target=/kaggle/working pymap3d==2.1.0',\n",
    "    'pip install --target=/kaggle/working protobuf==3.12.2',\n",
    "    'pip install --target=/kaggle/working transforms3d',\n",
    "    'pip install --target=/kaggle/working zarr',\n",
    "    'pip install --target=/kaggle/working ptable',\n",
    "    'pip install --no-dependencies --target=/kaggle/working l5kit',\n",
    "]\n",
    "\n",
    "# os.system returns 0 when the command succeeds; keep every command that did not,\n",
    "# so a failed install is reported here instead of surfacing later as an ImportError.\n",
    "failed_commands = [cmd for cmd in pip_commands if os.system(cmd) != 0]\n",
    "if failed_commands:\n",
    "    print('WARNING: the following install commands failed:')\n",
    "    for cmd in failed_commands:\n",
    "        print('  ' + cmd)\n",
    "# os.system('pip install --target=/kaggle/working timm')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Requirement already satisfied: timm in /opt/conda/lib/python3.7/site-packages (0.3.1)\r\n",
      "Requirement already satisfied: torchvision in /opt/conda/lib/python3.7/site-packages (from timm) (0.8.1)\r\n",
      "Requirement already satisfied: torch>=1.0 in /opt/conda/lib/python3.7/site-packages (from timm) (1.7.0)\r\n",
      "Requirement already satisfied: future in /opt/conda/lib/python3.7/site-packages (from torch>=1.0->timm) (0.18.2)\r\n",
      "Requirement already satisfied: typing_extensions in /opt/conda/lib/python3.7/site-packages (from torch>=1.0->timm) (3.7.4.1)\r\n",
      "Requirement already satisfied: dataclasses in /opt/conda/lib/python3.7/site-packages (from torch>=1.0->timm) (0.6)\r\n",
      "Requirement already satisfied: numpy in /opt/conda/lib/python3.7/site-packages (from torch>=1.0->timm) (1.18.5)\r\n",
      "Requirement already satisfied: numpy in /opt/conda/lib/python3.7/site-packages (from torch>=1.0->timm) (1.18.5)\r\n",
      "Requirement already satisfied: torch>=1.0 in /opt/conda/lib/python3.7/site-packages (from timm) (1.7.0)\r\n",
      "Requirement already satisfied: pillow>=4.1.1 in /opt/conda/lib/python3.7/site-packages (from torchvision->timm) (8.0.1)\r\n"
     ]
    }
   ],
   "source": [
    "%pip install timm==0.3.1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {
    "_cell_guid": "79c7e3d0-c299-4dcb-8224-4455121ee9b0",
    "_uuid": "d629ff2d2480ee46fbb7e2d37f6b5fab8052498a"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['multi_mode_sample_submission.csv', 'semantic_map', 'aerial_map', 'single_mode_sample_submission.csv', 'meta.json', 'scenes']\n"
     ]
    }
   ],
   "source": [
    "# import packages\n",
    "import os, gc\n",
    "import zarr\n",
    "import numpy as np \n",
    "import pandas as pd \n",
    "from tqdm import tqdm\n",
    "from typing import Dict\n",
    "from collections import Counter\n",
    "from prettytable import PrettyTable\n",
    "from collections import OrderedDict\n",
    "import math\n",
    "import pickle\n",
    "from IPython.display import FileLink\n",
    "\n",
    "#level5 toolkit\n",
    "from l5kit.data import PERCEPTION_LABELS\n",
    "from l5kit.dataset import EgoDataset, AgentDataset\n",
    "from l5kit.data import ChunkedDataset, LocalDataManager\n",
    "\n",
    "# level5 toolkit \n",
    "from l5kit.configs import load_config_data\n",
    "from l5kit.geometry import transform_points\n",
    "from l5kit.rasterization import build_rasterizer\n",
    "from l5kit.visualization import draw_trajectory, draw_reference_trajectory, TARGET_POINTS_COLOR, PREDICTED_POINTS_COLOR, write_gif\n",
    "from l5kit.evaluation import write_pred_csv, compute_metrics_csv, read_gt_csv, create_chopped_dataset, export_zarr_to_csv, write_gt_csv\n",
    "from l5kit.evaluation.metrics import neg_multi_log_likelihood, time_displace\n",
    "\n",
    "# visualization\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib import animation\n",
    "from colorama import Fore, Back, Style\n",
    "\n",
    "# deep learning\n",
    "import torch\n",
    "from torch import nn, optim, Tensor\n",
    "from torch.utils.data import DataLoader\n",
    "from torchvision.models.resnet import resnet18, resnet50, resnet34\n",
    "from torchvision.models.mobilenet import mobilenet_v2\n",
    "from torch.nn import functional as F\n",
    "import timm\n",
    "\n",
    "# check files in directory\n",
    "print((os.listdir('../input/lyft-motion-prediction-autonomous-vehicles/')))\n",
    "\n",
    "plt.rc('animation', html='jshtml')\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train Configurations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "DEBUG = False\n",
    "\n",
    "# training cfg (DEBUG only shortens 'max_num_steps' at the bottom of this dict)\n",
    "training_cfg = {\n",
    "    \n",
    "    'format_version': 4,\n",
    "    \n",
    "     ## Model options\n",
    "    'model_params': {\n",
    "        'model_architecture': 'mixnet_l',\n",
    "        'history_num_frames': 10,\n",
    "        'history_step_size': 1,\n",
    "        'history_delta_time': 0.1,\n",
    "        'future_num_frames': 50,\n",
    "        'future_step_size': 1,\n",
    "        'future_delta_time': 0.1,\n",
    "    },\n",
    "\n",
    "    ## Input raster parameters\n",
    "    'raster_params': {\n",
    "        \n",
    "        'raster_size': [300, 300], # raster image size [pixels]\n",
    "        'pixel_size': [0.5, 0.5], # raster's spatial resolution [meters per pixel]: the size in the real world one pixel corresponds to.\n",
    "        'ego_center': [0.25, 0.5], # From 0 to 1 per axis, [0.5,0.5] would show the ego centered in the image.\n",
    "        'map_type': \"py_semantic\",\n",
    "        \n",
    "        # the keys are relative to the dataset environment variable\n",
    "        'satellite_map_key': \"aerial_map/aerial_map.png\",\n",
    "        'semantic_map_key': \"semantic_map/semantic_map.pb\",\n",
    "        'dataset_meta_key': \"meta.json\",\n",
    "\n",
    "        # e.g. 0.0 include every obstacle, 0.5 show those obstacles with >0.5 probability of being\n",
    "        # one of the classes we care about (cars, bikes, peds, etc.), >=1.0 filter all other agents.\n",
    "        'filter_agents_threshold': 0.5,\n",
    "        'disable_traffic_light_faces': False\n",
    "\n",
    "    },\n",
    "\n",
    "    ## Data loader options\n",
    "    'train_data_loader': {\n",
    "        'key': \"scenes/train.zarr\",\n",
    "        'batch_size': 32,\n",
    "        'shuffle': True,\n",
    "        'num_workers': 4\n",
    "    },\n",
    "\n",
    "    ## Train params\n",
    "    'train_params': {\n",
    "        'checkpoint_every_n_steps': 5000,\n",
    "        'max_num_steps': 100 if DEBUG else 10000\n",
    "    }\n",
    "}\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'format_version': 4, 'model_params': {'model_architecture': 'mixnet_l', 'history_num_frames': 10, 'history_step_size': 1, 'history_delta_time': 0.1, 'future_num_frames': 50, 'future_step_size': 1, 'future_delta_time': 0.1}, 'raster_params': {'raster_size': [300, 300], 'pixel_size': [0.5, 0.5], 'ego_center': [0.25, 0.5], 'map_type': 'py_semantic', 'satellite_map_key': 'aerial_map/aerial_map.png', 'semantic_map_key': 'semantic_map/semantic_map.pb', 'dataset_meta_key': 'meta.json', 'filter_agents_threshold': 0.5, 'disable_traffic_light_faces': False}, 'train_data_loader': {'key': 'scenes/train.zarr', 'batch_size': 32, 'shuffle': True, 'num_workers': 4}, 'train_params': {'checkpoint_every_n_steps': 5000, 'max_num_steps': 10000}}\n"
     ]
    }
   ],
   "source": [
    "# root directory of the competition data (Kaggle input mount)\n",
    "DIR_INPUT = \"/kaggle/input/lyft-motion-prediction-autonomous-vehicles\"\n",
    "\n",
    "# set env variable for data before the data manager is created\n",
    "# NOTE(review): presumably LocalDataManager(None) falls back to L5KIT_DATA_FOLDER -- confirm in the l5kit docs\n",
    "os.environ[\"L5KIT_DATA_FOLDER\"] = DIR_INPUT\n",
    "dm = LocalDataManager(None)\n",
    "# echo the configuration defined in the previous cell into the notebook output\n",
    "print(training_cfg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+------------+------------+------------+---------------+-----------------+----------------------+----------------------+----------------------+---------------------+\n",
      "| Num Scenes | Num Frames | Num Agents | Num TR lights | Total Time (hr) | Avg Frames per Scene | Avg Agents per Frame | Avg Scene Time (sec) | Avg Frame frequency |\n",
      "+------------+------------+------------+---------------+-----------------+----------------------+----------------------+----------------------+---------------------+\n",
      "|   16265    |  4039527   | 320124624  |    38735988   |      112.19     |        248.36        |        79.25         |        24.83         |        10.00        |\n",
      "+------------+------------+------------+---------------+-----------------+----------------------+----------------------+----------------------+---------------------+\n"
     ]
    }
   ],
   "source": [
    "# data-loader section of the config (key, batch_size, shuffle, num_workers)\n",
    "train_cfg = training_cfg[\"train_data_loader\"]\n",
    "\n",
    "# rasterizer, built from the full config and the data manager\n",
    "rasterizer = build_rasterizer(training_cfg, dm)\n",
    "\n",
    "# dataloader: open the zarr store located by the data manager, wrap it in an\n",
    "# AgentDataset together with the rasterizer, then batch it with a torch DataLoader\n",
    "train_zarr = ChunkedDataset(dm.require(train_cfg[\"key\"])).open()\n",
    "train_dataset = AgentDataset(training_cfg, train_zarr, rasterizer)\n",
    "train_dataloader = DataLoader(train_dataset, shuffle=train_cfg[\"shuffle\"], batch_size=train_cfg[\"batch_size\"], \n",
    "                             num_workers=train_cfg[\"num_workers\"])\n",
    "# printing the dataset shows its summary table in the cell output\n",
    "print(train_dataset)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Model Definition\n",
    "\n",
    "## LyftModel - simple Resnet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LyftModel(nn.Module):\n",
    "    \"\"\"ResNet-50 backbone with a two-layer linear head that regresses one future trajectory.\"\"\"\n",
    "\n",
    "    def __init__(self, cfg):\n",
    "        \"\"\"Build the backbone and the regression head.\n",
    "\n",
    "        Args:\n",
    "            cfg: config dict; cfg['model_params'] must provide\n",
    "                'history_num_frames' and 'future_num_frames'.\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "\n",
    "        # set pretrained=True while training\n",
    "        self.backbone = resnet50(pretrained=True)\n",
    "\n",
    "        # (history_num_frames + 1) * 2 history channels plus 3 more channels\n",
    "        # (presumably the RGB map -- confirm against the rasterizer); 25 for 10 history frames.\n",
    "        num_history_channels = (cfg[\"model_params\"][\"history_num_frames\"] + 1) * 2\n",
    "        num_in_channels = 3 + num_history_channels\n",
    "\n",
    "        # Replace the stem so it accepts num_in_channels inputs. The new layer is freshly\n",
    "        # initialised; every other hyper-parameter is copied from the original conv1.\n",
    "        stem = self.backbone.conv1\n",
    "        self.backbone.conv1 = nn.Conv2d(\n",
    "            num_in_channels,\n",
    "            stem.out_channels,\n",
    "            kernel_size=stem.kernel_size,\n",
    "            stride=stem.stride,\n",
    "            padding=stem.padding,\n",
    "            bias=False,\n",
    "        )\n",
    "\n",
    "        # Read the pooled feature width from the backbone's own classifier instead of\n",
    "        # hard-coding it (512 for resnet18/34, 2048 for the other resnets), so the head\n",
    "        # keeps matching if the backbone above is swapped.\n",
    "        backbone_out_features = self.backbone.fc.in_features\n",
    "\n",
    "        # X, Y coords for the future positions, flattened to 2 * future_num_frames values\n",
    "        num_targets = 2 * cfg[\"model_params\"][\"future_num_frames\"]\n",
    "\n",
    "        # You can add more layers here.\n",
    "        self.head = nn.Sequential(\n",
    "            # nn.Dropout(0.2),\n",
    "            nn.Linear(in_features=backbone_out_features, out_features=4096),\n",
    "        )\n",
    "\n",
    "        self.logit = nn.Linear(4096, out_features=num_targets)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Map a batch of rasters to a (batch, 2 * future_num_frames) tensor.\"\"\"\n",
    "        # The backbone is run layer by layer so that its own fc classifier is skipped.\n",
    "        x = self.backbone.conv1(x)\n",
    "        x = self.backbone.bn1(x)\n",
    "        x = self.backbone.relu(x)\n",
    "        x = self.backbone.maxpool(x)\n",
    "\n",
    "        x = self.backbone.layer1(x)\n",
    "        x = self.backbone.layer2(x)\n",
    "        x = self.backbone.layer3(x)\n",
    "        x = self.backbone.layer4(x)\n",
    "\n",
    "        x = self.backbone.avgpool(x)\n",
    "        x = torch.flatten(x, 1)\n",
    "\n",
    "        x = self.head(x)\n",
    "        x = self.logit(x)\n",
    "\n",
    "        return x"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Mixnet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LyftMixnet(nn.Module):\n",
    "    \"\"\"timm backbone ('mixnet_l' by default) followed by a linear layer that outputs the trajectory.\"\"\"\n",
    "\n",
    "    def __init__(self, cfg, classify='mixnet_l'):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            cfg: config dict; cfg['model_params'] must provide\n",
    "                'history_num_frames' and 'future_num_frames'.\n",
    "            classify: model name handed to timm.create_model.\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "\n",
    "        # set pretrained=True while training\n",
    "        self.backbone = timm.create_model(classify, pretrained=True)\n",
    "\n",
    "        # 3 channels plus two per frame over (history_num_frames + 1) frames; 25 for 10 history frames\n",
    "        in_channels = 3 + 2 * (cfg[\"model_params\"][\"history_num_frames\"] + 1)\n",
    "\n",
    "        # New stem with the raster's channel count; geometry is copied from the original stem.\n",
    "        old_stem = self.backbone.conv_stem\n",
    "        self.backbone.conv_stem = nn.Conv2d(\n",
    "            in_channels,\n",
    "            old_stem.out_channels,\n",
    "            kernel_size=old_stem.kernel_size,\n",
    "            stride=old_stem.stride,\n",
    "            padding=old_stem.padding,\n",
    "            bias=False,\n",
    "        )\n",
    "\n",
    "        # One (x, y) pair per future frame, flattened.\n",
    "        out_dim = cfg[\"model_params\"][\"future_num_frames\"] * 2\n",
    "        # The regressor sits on top of the backbone's classifier output.\n",
    "        self.logit = nn.Linear(self.backbone.classifier.out_features, out_features=out_dim)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Run the backbone, then the linear regressor.\"\"\"\n",
    "        return self.logit(self.backbone(x))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## MobileNet v2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LyftMobile(nn.Module):\n",
    "    \"\"\"MobileNetV2 with its first conv and last linear layer replaced to fit the raster input and trajectory output.\"\"\"\n",
    "\n",
    "    def __init__(self, cfg):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            cfg: config dict; cfg['model_params'] must provide\n",
    "                'history_num_frames' and 'future_num_frames'.\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "\n",
    "        # set pretrained=True while training\n",
    "        self.backbone = mobilenet_v2(pretrained=True)\n",
    "\n",
    "        # 3 channels plus two per frame over (history_num_frames + 1) frames; 25 for 10 history frames\n",
    "        in_channels = 3 + 2 * (cfg[\"model_params\"][\"history_num_frames\"] + 1)\n",
    "\n",
    "        # Swap the very first convolution; geometry is copied from the layer it replaces.\n",
    "        first_conv = self.backbone.features[0][0]\n",
    "        self.backbone.features[0][0] = nn.Conv2d(\n",
    "            in_channels,\n",
    "            first_conv.out_channels,\n",
    "            kernel_size=first_conv.kernel_size,\n",
    "            stride=first_conv.stride,\n",
    "            padding=first_conv.padding,\n",
    "            bias=False,\n",
    "        )\n",
    "\n",
    "        # One (x, y) pair per future frame, flattened.\n",
    "        out_dim = cfg[\"model_params\"][\"future_num_frames\"] * 2\n",
    "\n",
    "        # Fully connected layer: same input width as before, new output width.\n",
    "        old_fc = self.backbone.classifier[1]\n",
    "        self.backbone.classifier[1] = nn.Linear(\n",
    "            in_features=old_fc.in_features,\n",
    "            out_features=out_dim,\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return the backbone output; the replaced classifier already emits the targets.\"\"\"\n",
    "        return self.backbone(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Multi-Modal"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- Function utils ---\n",
    "# Original code from https://github.com/lyft/l5kit/blob/20ab033c01610d711c3d36e1963ecec86e8b85b6/l5kit/l5kit/evaluation/metrics.py\n",
    "\n",
    "def pytorch_neg_multi_log_likelihood_batch(\n",
    "    gt: Tensor, pred: Tensor, confidences: Tensor, avails: Tensor\n",
    ") -> Tensor:\n",
    "    \"\"\"\n",
    "    Compute the negative log-likelihood for the multi-modal scenario, averaged over the batch.\n",
    "    log-sum-exp is used here to avoid underflow and overflow, For more information about it see:\n",
    "    https://en.wikipedia.org/wiki/LogSumExp#log-sum-exp_trick_for_log-domain_calculations\n",
    "    https://timvieira.github.io/blog/post/2014/02/11/exp-normalize-trick/\n",
    "    https://leimao.github.io/blog/LogSumExp/\n",
    "    Args:\n",
    "        gt (Tensor): array of shape (bs)x(time)x(2D coords)\n",
    "        pred (Tensor): array of shape (bs)x(modes)x(time)x(2D coords)\n",
    "        confidences (Tensor): array of shape (bs)x(modes) with a confidence for each mode in each sample\n",
    "        avails (Tensor): array of shape (bs)x(time) with the availability for each gt timestep\n",
    "    Returns:\n",
    "        Tensor: scalar tensor holding the mean negative log-likelihood over the batch\n",
    "    Raises:\n",
    "        AssertionError: if the shapes are inconsistent, the confidences of a sample do not\n",
    "            sum to 1, or any input contains a non-finite value\n",
    "    \"\"\"\n",
    "    assert len(pred.shape) == 4, f\"expected 4D (Batch x Modes x Time x Coords) tensor for pred, got {pred.shape}\"\n",
    "    batch_size, num_modes, future_len, num_coords = pred.shape\n",
    "\n",
    "    assert gt.shape == (batch_size, future_len, num_coords), f\"expected 3D (Batch x Time x Coords) tensor for gt, got {gt.shape}\"\n",
    "    assert confidences.shape == (batch_size, num_modes), f\"expected 2D (Batch x Modes) tensor for confidences, got {confidences.shape}\"\n",
    "    # computed once; the offending sums go into the assertion message instead of a separate print\n",
    "    confidence_sums = torch.sum(confidences, dim=1)\n",
    "    assert torch.allclose(confidence_sums, confidences.new_ones((batch_size,))), (\n",
    "        f\"confidences should sum to 1, got per-sample sums {confidence_sums}\"\n",
    "    )\n",
    "    assert avails.shape == (batch_size, future_len), f\"expected 2D (Batch x Time) tensor for avails, got {avails.shape}\"\n",
    "    # assert all data are valid\n",
    "    assert torch.isfinite(pred).all(), \"invalid value found in pred\"\n",
    "    assert torch.isfinite(gt).all(), \"invalid value found in gt\"\n",
    "    assert torch.isfinite(confidences).all(), \"invalid value found in confidences\"\n",
    "    assert torch.isfinite(avails).all(), \"invalid value found in avails\"\n",
    "\n",
    "    # convert to (batch_size, num_modes, future_len, num_coords)\n",
    "    gt = torch.unsqueeze(gt, 1)  # add modes\n",
    "    avails = avails[:, None, :, None]  # add modes and cords\n",
    "\n",
    "    # error (batch_size, num_modes, future_len)\n",
    "    error = torch.sum(((gt - pred) * avails) ** 2, dim=-1)  # reduce coords and use availability\n",
    "\n",
    "    # when confidence is 0 log goes to -inf, but we're fine with it\n",
    "    # (no np.errstate needed: it does not affect torch operations)\n",
    "    # error (batch_size, num_modes)\n",
    "    error = torch.log(confidences) - 0.5 * torch.sum(error, dim=-1)  # reduce time\n",
    "\n",
    "    # torch.logsumexp performs the max-shifted reduction over modes in a numerically stable way\n",
    "    # error (batch_size, 1)\n",
    "    error = -torch.logsumexp(error, dim=-1, keepdim=True)  # reduce modes\n",
    "    return torch.mean(error)\n",
    "\n",
    "\n",
    "def pytorch_neg_multi_log_likelihood_single(\n",
    "    gt: Tensor, pred: Tensor, avails: Tensor\n",
    ") -> Tensor:\n",
    "    \"\"\"\n",
    "    Single-mode negative log-likelihood: treats pred as the only mode, with confidence 1,\n",
    "    and delegates to pytorch_neg_multi_log_likelihood_batch.\n",
    "\n",
    "    Args:\n",
    "        gt (Tensor): array of shape (bs)x(time)x(2D coords)\n",
    "        pred (Tensor): array of shape (bs)x(time)x(2D coords)\n",
    "        avails (Tensor): array of shape (bs)x(time) with the availability for each gt timestep\n",
    "    Returns:\n",
    "        Tensor: negative log-likelihood for this example, a single float number\n",
    "    \"\"\"\n",
    "    # the three-way unpack only succeeds when pred has exactly three dimensions\n",
    "    num_samples, _, _ = pred.shape\n",
    "    # (bs)x(time)x(2D coords) --> (bs)x(mode=1)x(time)x(2D coords)\n",
    "    single_mode_pred = pred.unsqueeze(1)\n",
    "    # confidence (bs)x(mode=1), all ones\n",
    "    full_confidence = pred.new_ones((num_samples, 1))\n",
    "    return pytorch_neg_multi_log_likelihood_batch(gt, single_mode_pred, full_confidence, avails)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## LSTM layer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Hyperparameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [
    {
     "ename": "RuntimeError",
     "evalue": "CUDA out of memory. Tried to allocate 2.00 MiB (GPU 0; 15.90 GiB total capacity; 15.00 GiB already allocated; 3.75 MiB free; 15.13 GiB reserved in total by PyTorch)",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-37-7a157c8e1516>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     16\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     17\u001b[0m     \u001b[0mWEIGHT_FILE\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m'../input/resnet-34-pth/model_state_mixnetl_25000.pth'\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 18\u001b[0;31m     \u001b[0mmodel_state\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mWEIGHT_FILE\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmap_location\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     19\u001b[0m     \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload_state_dict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel_state\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     20\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/opt/conda/lib/python3.7/site-packages/torch/serialization.py\u001b[0m in \u001b[0;36mload\u001b[0;34m(f, map_location, pickle_module, **pickle_load_args)\u001b[0m\n\u001b[1;32m    592\u001b[0m                     \u001b[0mopened_file\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mseek\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0morig_position\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    593\u001b[0m                     \u001b[0;32mreturn\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjit\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mopened_file\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 594\u001b[0;31m                 \u001b[0;32mreturn\u001b[0m \u001b[0m_load\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mopened_zipfile\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmap_location\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpickle_module\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mpickle_load_args\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    595\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0m_legacy_load\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mopened_file\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmap_location\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpickle_module\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mpickle_load_args\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    596\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/opt/conda/lib/python3.7/site-packages/torch/serialization.py\u001b[0m in \u001b[0;36m_load\u001b[0;34m(zip_file, map_location, pickle_module, pickle_file, **pickle_load_args)\u001b[0m\n\u001b[1;32m    851\u001b[0m     \u001b[0munpickler\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpickle_module\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mUnpickler\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata_file\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mpickle_load_args\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    852\u001b[0m     \u001b[0munpickler\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpersistent_load\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpersistent_load\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 853\u001b[0;31m     \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0munpickler\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    854\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    855\u001b[0m     \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_utils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_validate_loaded_sparse_tensors\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/opt/conda/lib/python3.7/site-packages/torch/serialization.py\u001b[0m in \u001b[0;36mpersistent_load\u001b[0;34m(saved_id)\u001b[0m\n\u001b[1;32m    843\u001b[0m         \u001b[0mdata_type\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkey\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlocation\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msize\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    844\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mkey\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mloaded_storages\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 845\u001b[0;31m             \u001b[0mload_tensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata_type\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msize\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkey\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_maybe_decode_ascii\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlocation\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    846\u001b[0m         \u001b[0mstorage\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mloaded_storages\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    847\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0mstorage\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/opt/conda/lib/python3.7/site-packages/torch/serialization.py\u001b[0m in \u001b[0;36mload_tensor\u001b[0;34m(data_type, size, key, location)\u001b[0m\n\u001b[1;32m    832\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    833\u001b[0m         \u001b[0mstorage\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mzip_file\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_storage_from_record\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msize\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdtype\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstorage\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 834\u001b[0;31m         \u001b[0mloaded_storages\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mrestore_location\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstorage\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlocation\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    835\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    836\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mpersistent_load\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msaved_id\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/opt/conda/lib/python3.7/site-packages/torch/serialization.py\u001b[0m in \u001b[0;36mrestore_location\u001b[0;34m(storage, location)\u001b[0m\n\u001b[1;32m    812\u001b[0m     \u001b[0;32melif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmap_location\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    813\u001b[0m         \u001b[0;32mdef\u001b[0m \u001b[0mrestore_location\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstorage\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlocation\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 814\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mdefault_restore_location\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstorage\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmap_location\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    815\u001b[0m     \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    816\u001b[0m         \u001b[0;32mdef\u001b[0m \u001b[0mrestore_location\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstorage\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlocation\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/opt/conda/lib/python3.7/site-packages/torch/serialization.py\u001b[0m in \u001b[0;36mdefault_restore_location\u001b[0;34m(storage, location)\u001b[0m\n\u001b[1;32m    173\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mdefault_restore_location\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstorage\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlocation\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    174\u001b[0m     \u001b[0;32mfor\u001b[0m \u001b[0m_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfn\u001b[0m \u001b[0;32min\u001b[0m \u001b[0m_package_registry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 175\u001b[0;31m         \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstorage\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlocation\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    176\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    177\u001b[0m             \u001b[0;32mreturn\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/opt/conda/lib/python3.7/site-packages/torch/serialization.py\u001b[0m in \u001b[0;36m_cuda_deserialize\u001b[0;34m(obj, location)\u001b[0m\n\u001b[1;32m    155\u001b[0m                 \u001b[0;32mreturn\u001b[0m \u001b[0mstorage_type\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mobj\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    156\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 157\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mobj\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcuda\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    158\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    159\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/opt/conda/lib/python3.7/site-packages/torch/_utils.py\u001b[0m in \u001b[0;36m_cuda\u001b[0;34m(self, device, non_blocking, **kwargs)\u001b[0m\n\u001b[1;32m     77\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     78\u001b[0m             \u001b[0mnew_type\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mgetattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcuda\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__class__\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__name__\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 79\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mnew_type\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcopy_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_blocking\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     80\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     81\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/opt/conda/lib/python3.7/site-packages/torch/cuda/__init__.py\u001b[0m in \u001b[0;36m_lazy_new\u001b[0;34m(cls, *args, **kwargs)\u001b[0m\n\u001b[1;32m    460\u001b[0m     \u001b[0;31m# We may need to call lazy init again if we are a forked child\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    461\u001b[0m     \u001b[0;31m# del _CudaBase.__new__\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 462\u001b[0;31m     \u001b[0;32mreturn\u001b[0m \u001b[0msuper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_CudaBase\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcls\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__new__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcls\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    463\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    464\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mRuntimeError\u001b[0m: CUDA out of memory. Tried to allocate 2.00 MiB (GPU 0; 15.90 GiB total capacity; 15.00 GiB already allocated; 3.75 MiB free; 15.13 GiB reserved in total by PyTorch)"
     ]
    }
   ],
   "source": [
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "# Choose which model / optimizer / loss configuration to build (1-4).\n",
    "# NOTE(review): the training loop calls criterion(targets, outputs, availabilities);\n",
    "# nn.SmoothL1Loss (conf 1, 3, 4) accepts only (input, target) -- confirm before switching.\n",
    "conf = 2\n",
    "\n",
    "if conf == 1:\n",
    "    # resnet stuff\n",
    "    model = LyftModel(training_cfg).to(device)\n",
    "    optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=2e-6)\n",
    "    criterion = nn.SmoothL1Loss()\n",
    "    # criterion = nn.MSELoss(reduction=\"none\")\n",
    "\n",
    "elif conf == 2:\n",
    "    # mixnet_large, resumed from a previously saved state dict\n",
    "    model = LyftMixnet(training_cfg, 'mixnet_l').to(device)\n",
    "\n",
    "    WEIGHT_FILE = '../input/resnet-34-pth/model_state_mixnetl_25000.pth'\n",
    "    # Deserialize the checkpoint onto the CPU. Mapping it straight to the GPU\n",
    "    # allocates a second copy of every tensor there, next to the model that is\n",
    "    # already resident; load_state_dict() copies the values into the model's\n",
    "    # existing parameters on `device` anyway.\n",
    "    model_state = torch.load(WEIGHT_FILE, map_location=\"cpu\")\n",
    "    model.load_state_dict(model_state)\n",
    "\n",
    "    optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=2e-6)\n",
    "    criterion = pytorch_neg_multi_log_likelihood_single\n",
    "\n",
    "elif conf == 3:\n",
    "    # mixnet_medium\n",
    "    model = LyftMixnet(training_cfg, 'mixnet_m').to(device)\n",
    "    optimizer = optim.Adam(model.parameters(), lr=4e-4, weight_decay=2e-6)\n",
    "    criterion = nn.SmoothL1Loss()\n",
    "\n",
    "elif conf == 4:\n",
    "    model = LyftMobile(training_cfg).to(device)\n",
    "    optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=2e-6)\n",
    "    criterion = nn.SmoothL1Loss()\n",
    "\n",
    "else:\n",
    "    # Fail here rather than later with a NameError on model/optimizer/criterion.\n",
    "    raise ValueError(f\"Unknown conf {conf!r}; expected 1, 2, 3 or 4\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Validate Configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Validation configuration: model horizon, rasterizer settings and data loader options.\n",
    "validate_cfg = dict(\n",
    "    format_version=4,\n",
    "    # history / future frame settings\n",
    "    model_params=dict(\n",
    "        history_num_frames=10,\n",
    "        history_step_size=1,\n",
    "        history_delta_time=0.1,\n",
    "        future_num_frames=50,\n",
    "        future_step_size=1,\n",
    "        future_delta_time=0.1,\n",
    "    ),\n",
    "    # rasterizer settings\n",
    "    raster_params=dict(\n",
    "        raster_size=[300, 300],\n",
    "        pixel_size=[0.5, 0.5],\n",
    "        ego_center=[0.25, 0.5],\n",
    "        map_type='py_semantic',\n",
    "        satellite_map_key='aerial_map/aerial_map.png',\n",
    "        semantic_map_key='semantic_map/semantic_map.pb',\n",
    "        dataset_meta_key='meta.json',\n",
    "        filter_agents_threshold=0.5,\n",
    "        disable_traffic_light_faces=False,\n",
    "    ),\n",
    "    # options for the validation DataLoader\n",
    "    validate_data_loader=dict(\n",
    "        key='scenes/validate.zarr',\n",
    "        batch_size=6,\n",
    "        shuffle=False,\n",
    "        num_workers=4,\n",
    "    ),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {
    "collapsed": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/10000 [00:06<?, ?it/s]\n"
     ]
    },
    {
     "ename": "RuntimeError",
     "evalue": "CUDA out of memory. Tried to allocate 28.00 MiB (GPU 0; 15.90 GiB total capacity; 14.97 GiB already allocated; 9.75 MiB free; 15.13 GiB reserved in total by PyTorch)",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-31-d6666f042241>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     19\u001b[0m     \u001b[0mtargets\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"target_positions\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     20\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 21\u001b[0;31m     \u001b[0moutputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreshape\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtargets\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     22\u001b[0m     \u001b[0;31m# use Negative Multi Log Likelihood\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     23\u001b[0m     \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcriterion\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtargets\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtarget_availabilities\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m    725\u001b[0m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    726\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 727\u001b[0;31m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    728\u001b[0m         for hook in itertools.chain(\n\u001b[1;32m    729\u001b[0m                 \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m<ipython-input-20-4fa415c88663>\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m     25\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     26\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 27\u001b[0;31m         \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackbone\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     28\u001b[0m         \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlogit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     29\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m    725\u001b[0m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    726\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 727\u001b[0;31m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    728\u001b[0m         for hook in itertools.chain(\n\u001b[1;32m    729\u001b[0m                 \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/opt/conda/lib/python3.7/site-packages/timm/models/efficientnet.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m    389\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    390\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 391\u001b[0;31m         \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward_features\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    392\u001b[0m         \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mglobal_pool\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    393\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdrop_rate\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0;36m0.\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/opt/conda/lib/python3.7/site-packages/timm/models/efficientnet.py\u001b[0m in \u001b[0;36mforward_features\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m    382\u001b[0m         \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbn1\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    383\u001b[0m         \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mact1\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 384\u001b[0;31m         \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mblocks\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    385\u001b[0m         \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconv_head\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    386\u001b[0m         \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbn2\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m    725\u001b[0m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    726\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 727\u001b[0;31m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    728\u001b[0m         for hook in itertools.chain(\n\u001b[1;32m    729\u001b[0m                 \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/opt/conda/lib/python3.7/site-packages/torch/nn/modules/container.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m    115\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    116\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0mmodule\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 117\u001b[0;31m             \u001b[0minput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodule\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    118\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    119\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m    725\u001b[0m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    726\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 727\u001b[0;31m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    728\u001b[0m         for hook in itertools.chain(\n\u001b[1;32m    729\u001b[0m                 \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/opt/conda/lib/python3.7/site-packages/torch/nn/modules/container.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m    115\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    116\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0mmodule\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 117\u001b[0;31m             \u001b[0minput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodule\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    118\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    119\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m    725\u001b[0m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    726\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 727\u001b[0;31m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    728\u001b[0m         for hook in itertools.chain(\n\u001b[1;32m    729\u001b[0m                 \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/opt/conda/lib/python3.7/site-packages/timm/models/efficientnet_blocks.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m    262\u001b[0m         \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconv_dw\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    263\u001b[0m         \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbn2\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 264\u001b[0;31m         \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mact2\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    265\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    266\u001b[0m         \u001b[0;31m# Squeeze-and-excitation\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m    725\u001b[0m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    726\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 727\u001b[0;31m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    728\u001b[0m         for hook in itertools.chain(\n\u001b[1;32m    729\u001b[0m                 \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/opt/conda/lib/python3.7/site-packages/torch/nn/modules/activation.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m    392\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    393\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 394\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msilu\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minplace\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minplace\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    395\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    396\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mextra_repr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/opt/conda/lib/python3.7/site-packages/torch/nn/functional.py\u001b[0m in \u001b[0;36msilu\u001b[0;34m(input, inplace)\u001b[0m\n\u001b[1;32m   1738\u001b[0m             \u001b[0;32mreturn\u001b[0m \u001b[0mhandle_torch_function\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msilu\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minplace\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minplace\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1739\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0minplace\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1740\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_C\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_nn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msilu_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1741\u001b[0m     \u001b[0;32mreturn\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_C\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_nn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msilu\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1742\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mRuntimeError\u001b[0m: CUDA out of memory. Tried to allocate 28.00 MiB (GPU 0; 15.90 GiB total capacity; 14.97 GiB already allocated; 9.75 MiB free; 15.13 GiB reserved in total by PyTorch)"
     ]
    }
   ],
   "source": [
    "tr_it = iter(train_dataloader)\n",
    "progress_bar = tqdm(range(training_cfg[\"train_params\"][\"max_num_steps\"]))\n",
    "\n",
    "losses_train = []\n",
    "likelihood_valid = []\n",
    "# Running total of the losses, so the displayed average costs O(1) per step\n",
    "# instead of re-averaging the whole (growing) losses_train list.\n",
    "loss_sum = 0.0\n",
    "\n",
    "# Nothing inside the loop leaves training mode or disables autograd, so set both once.\n",
    "model.train()\n",
    "torch.set_grad_enabled(True)\n",
    "\n",
    "for iter_step in progress_bar:\n",
    "    # Restart the iterator when the dataloader is exhausted so the loop can\n",
    "    # run for the configured number of steps.\n",
    "    try:\n",
    "        data = next(tr_it)\n",
    "    except StopIteration:\n",
    "        tr_it = iter(train_dataloader)\n",
    "        data = next(tr_it)\n",
    "\n",
    "    # forward pass\n",
    "    inputs = data[\"image\"].to(device)\n",
    "    target_availabilities = data[\"target_availabilities\"].unsqueeze(-1).to(device)\n",
    "    targets = data[\"target_positions\"].to(device)\n",
    "\n",
    "    outputs = model(inputs).reshape(targets.shape).to(device)\n",
    "    # use Negative Multi Log Likelihood\n",
    "    loss = criterion(targets, outputs, target_availabilities.squeeze(-1))\n",
    "\n",
    "#     loss = criterion(outputs, targets)\n",
    "#     # not all the output steps are valid, but we can filter them out from the loss using availabilities\n",
    "#     loss = loss * target_availabilities\n",
    "#     loss = loss.mean()\n",
    "\n",
    "    # Backward pass\n",
    "    optimizer.zero_grad()\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "\n",
    "    # .item() copies the scalar to the host; read it once and reuse the value.\n",
    "    loss_value = loss.item()\n",
    "    losses_train.append(loss_value)\n",
    "    loss_sum += loss_value\n",
    "\n",
    "    progress_bar.set_description(f\"loss: {loss_value} loss(avg): {loss_sum / len(losses_train)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Loss Plot & NLL Plot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training Loss plot\n",
    "# Set the default figure size *before* plotting: rcParams only apply to figures\n",
    "# created afterwards, and plt.plot() creates the figure implicitly.\n",
    "plt.rcParams['figure.figsize'] = 10, 10\n",
    "\n",
    "# x axis: 1-based index of each recorded training loss\n",
    "l = list(range(1, len(losses_train) + 1))\n",
    "# new_ticks=np.linspace(0,training_cfg[\"train_params\"][\"max_num_steps\"],6)\n",
    "plt.plot(l, losses_train, label=\"Training Loss\")\n",
    "# plt.xticks(new_ticks)\n",
    "plt.title(\"Loss Performance Versus Scenes\")\n",
    "plt.xlabel(\"Scenes\")\n",
    "plt.ylabel(\"Loss\")\n",
    "# the label passed to plt.plot() is only rendered once a legend is drawn\n",
    "plt.legend()\n",
    "# save before show(): the figure may be released once it has been shown\n",
    "plt.savefig(\"loss_mixnet_l_45000.png\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Validation Negative Log Likelihood Plot\n",
    "# l = [i for i in range(len(likelihood_valid))]\n",
    "# plt.plot(l, likelihood_valid,label=\"NLL\")\n",
    "# plt.xticks(l, [str(i * 1000) for i in l])\n",
    "# plt.title(\"Validation Negative Log Likelihood Versus Scenes\")\n",
    "# plt.rcParams['figure.figsize'] = 10, 10\n",
    "# plt.xlabel(\"Scenes\")\n",
    "# plt.ylabel(\"NLL\")\n",
    "# plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Save Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the trained model's state dict (weights only) to model_name\n",
    "model_name = \"model_state_mixnetl_45000.pth\"\n",
    "torch.save(\n",
    "    model.state_dict(),\n",
    "    model_name,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Switch to the Kaggle working directory, then display a download link for the checkpoint\n",
    "os.chdir('/kaggle/working')\n",
    "FileLink(path=model_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display a download link for the saved training-loss figure\n",
    "FileLink(path=\"loss_mixnet_l_45000.png\")\n",
    "# gc.collect()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# # validation configuration\n",
    "# valid_cfg = validate_cfg[\"validate_data_loader\"]\n",
    "\n",
    "# # Rasterizer\n",
    "# rasterizer = build_rasterizer(validate_cfg, dm)\n",
    "\n",
    "# # Validation dataset/dataloader\n",
    "# valid_zarr = ChunkedDataset(dm.require(valid_cfg[\"key\"])).open()\n",
    "# # valid_mask = np.load(f\"{DIR_INPUT}/scenes/mask.npz\")[\"arr_0\"]\n",
    "# valid_dataset = AgentDataset(validate_cfg, valid_zarr, rasterizer)\n",
    "# whole_size = valid_dataset.__len__()\n",
    "# valid_dataset_use, valid_dataset_valid, _ = torch.utils.data.random_split(valid_dataset, [70000, 2000, whole_size-72000], generator=torch.Generator().manual_seed(42))\n",
    "\n",
    "# valid_dataloader = DataLoader(valid_dataset_use,\n",
    "#                              shuffle=valid_cfg[\"shuffle\"],\n",
    "#                              batch_size=valid_cfg[\"batch_size\"],\n",
    "#                              num_workers=valid_cfg[\"num_workers\"])\n",
    "\n",
    "# valid_dataloader_valid = DataLoader(valid_dataset_valid,\n",
    "#                              shuffle=valid_cfg[\"shuffle\"],\n",
    "#                              batch_size=valid_cfg[\"batch_size\"],\n",
    "#                              num_workers=valid_cfg[\"num_workers\"])\n",
    "\n",
    "# print(valid_dataloader.dataset.__len__())\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# # Saved state dict from the training notebook\n",
    "# device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "# WEIGHT_FILE = '../input/resnet-34-pth/model_state_mixnetl_15000.pth'\n",
    "# model_state = torch.load(WEIGHT_FILE, map_location=device)\n",
    "# model.load_state_dict(model_state)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Final Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# # Final - Evaluate Validation Dataset \n",
    "# model.eval()\n",
    "# torch.set_grad_enabled(False)\n",
    "\n",
    "# # store information for evaluation\n",
    "# future_coords_offsets_pd = []\n",
    "# timestamps = []\n",
    "# # coordinates ground truth\n",
    "# valid_coords_gts = []\n",
    "# # target avalabilities\n",
    "# target_avail_pd = []\n",
    "# agent_ids = []\n",
    "# progress_bar = tqdm(valid_dataloader)\n",
    "# for data in progress_bar:\n",
    "    \n",
    "#     inputs = data[\"image\"].to(device)\n",
    "#     target_availabilities = data[\"target_availabilities\"].unsqueeze(-1).to(device)\n",
    "#     targets = data[\"target_positions\"].to(device)\n",
    "    \n",
    "#     outputs = model(inputs).reshape(targets.shape)\n",
    "    \n",
    "#     future_coords_offsets_pd.append(outputs.cpu().numpy().copy())\n",
    "#     timestamps.append(data[\"timestamp\"].numpy().copy())\n",
    "#     agent_ids.append(data[\"track_id\"].numpy().copy())\n",
    "#     valid_coords_gts.append(data[\"target_positions\"].numpy().copy())\n",
    "#     target_avail_pd.append(target_availabilities.cpu().numpy().copy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "_kg_hide-output": true
   },
   "outputs": [],
   "source": [
    "# timestamps_concat = np.concatenate(timestamps)\n",
    "# track_ids_concat = np.concatenate(agent_ids)\n",
    "# coords_concat = np.concatenate(future_coords_offsets_pd)\n",
    "# gt_valid_final = np.concatenate(valid_coords_gts)\n",
    "# target_avail_concat = np.concatenate(target_avail_pd)\n",
    "\n",
    "# # gt_valid_2D = np.reshape(gt_valid_final, (70000, 100))\n",
    "\n",
    "# # log_like = neg_multi_log_likelihood (\n",
    "# #     ground_truth=gt_valid_2D,\n",
    "# #     pred=coords_concat,\n",
    "# #     confidences=np.array([1,0,0]),\n",
    "# #     avails=target_avail_concat\n",
    "# # )\n",
    "\n",
    "# # print(\"Negative multi Likelihood is:\", log_like)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Ground Truth and Prediction CSV"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# # generate ground truth csv\n",
    "# write_gt_csv(\n",
    "#     csv_path=\"valid_gt.csv\", \n",
    "#     timestamps=timestamps_concat, \n",
    "#     track_ids=track_ids_concat, \n",
    "#     coords=gt_valid_final, \n",
    "#     avails=target_avail_concat.squeeze(-1)\n",
    "# )\n",
    "\n",
    "# # submission.csv\n",
    "# write_pred_csv('submission_mixnetl_15000.csv',\n",
    "#                timestamps=timestamps_concat,\n",
    "#                track_ids=track_ids_concat,\n",
    "#                coords=coords_concat,\n",
    "#               )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load CSV & Compute metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# # Negative Log Likelihood Metrics\n",
    "# eval_gt_path = \"../input/resnet-34-pth/valid_gt.csv\"\n",
    "# pred_path = \"submission_mixnetl_15000.csv\"\n",
    "\n",
    "# metrics = compute_metrics_csv(eval_gt_path, pred_path, [neg_multi_log_likelihood, time_displace])\n",
    "# for metric_name, metric_mean in metrics.items():\n",
    "#     print(metric_name, metric_mean)\n",
    "    \n",
    "# # Save Metric\n",
    "# np.save('metric_mixnetl_15000.npy',metrics)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plot Prediction Tractories"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# model.eval()\n",
    "# torch.set_grad_enabled(False)\n",
    "\n",
    "# # Uncomment to choose satelliter or semantic rasterizer\n",
    "# validate_cfg[\"raster_params\"][\"map_type\"] = \"py_satellite\"\n",
    "# # validate_cfg[\"raster_params\"][\"map_type\"] = \"py_semantic\"\n",
    "\n",
    "\n",
    "# rast = build_rasterizer(validate_cfg, dm)\n",
    "\n",
    "# eval_ego_dataset = EgoDataset(validate_cfg, valid_dataset.dataset, rast)\n",
    "# num_frames = 2 # randomly pick _ frames\n",
    "# random_frames = np.random.randint(0,len(eval_ego_dataset)-1, (num_frames,))\n",
    "\n",
    "# for frame_number in random_frames:  \n",
    "#     agent_indices = valid_dataset.get_frame_indices(frame_number) \n",
    "#     if not len(agent_indices):\n",
    "#         continue\n",
    "\n",
    "#     # get AV point-of-view frame\n",
    "#     data_ego = eval_ego_dataset[frame_number]\n",
    "#     im_ego = rasterizer.to_rgb(data_ego[\"image\"].transpose(1, 2, 0))\n",
    "#     center = np.asarray(validate_cfg[\"raster_params\"][\"ego_center\"]) * validate_cfg[\"raster_params\"][\"raster_size\"]\n",
    "    \n",
    "#     predicted_positions = []\n",
    "#     target_positions = []\n",
    "\n",
    "#     for v_index in agent_indices:\n",
    "#         data_agent = valid_dataset[v_index]\n",
    "\n",
    "#         out_net = model(torch.from_numpy(data_agent[\"image\"]).unsqueeze(0).to(device))\n",
    "#         out_pos = out_net[0].reshape(-1, 2).detach().cpu().numpy()\n",
    "#         # store absolute world coordinates\n",
    "#         predicted_positions.append(transform_points(out_pos, data_agent[\"world_from_agent\"]))\n",
    "#         # retrieve target positions from the GT and store as absolute coordinates\n",
    "#         track_id, timestamp = data_agent[\"track_id\"], data_agent[\"timestamp\"]\n",
    "#         target_positions.append(transform_points(data_agent[\"target_positions\"], data_agent[\"world_from_agent\"]) )\n",
    "\n",
    "#     # convert coordinates to AV point-of-view so we can draw them\n",
    "#     predicted_positions = transform_points(np.concatenate(predicted_positions), data_ego[\"raster_from_world\"])\n",
    "#     target_positions = transform_points(np.concatenate(target_positions), data_ego[\"raster_from_world\"])\n",
    "    \n",
    "#     # make sure ground truth and prediction have the same data size\n",
    "#     assert len(target_positions) == len(predicted_positions)\n",
    "    \n",
    "#     # draw_trajectory(im_ego, predicted_positions, PREDICTED_POINTS_COLOR)\n",
    "#     draw_trajectory(im_ego, target_positions, TARGET_POINTS_COLOR)\n",
    "    \n",
    "#     plt.rcParams['figure.figsize'] = 6, 6\n",
    "#     plt.imshow(im_ego[::-1])\n",
    "#     plt.show()\n",
    "    \n",
    "#     draw_trajectory(im_ego, predicted_positions, PREDICTED_POINTS_COLOR)\n",
    "    \n",
    "#     plt.rcParams['figure.figsize'] = 6, 6\n",
    "#     plt.imshow(im_ego[::-1])\n",
    "#     plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": true
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
